Keras triplet loss sample

Keras implementation of the triplet loss of [1].

Introduction

Commonly, a machine learning problem consists of four components: data (Section 1.2), a training objective (Section 1.3), a model (a parameterizable function, Section 1.4), and a training procedure (Section 1.5).

After defining the problem, we train the model in Section 2 and then evaluate it in Section 3.

Papers
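
  • [1] F. Schroff, D. Kalenichenko, J. Philbin: "FaceNet: A Unified Embedding for Face Recognition and Clustering", CVPR 2015.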

Resources

Imports


In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))



In [2]:
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
import numpy as np
import os
import keras.backend as K
import tensorflow as tf

assert tf.__version__.startswith("1.3"), "TensorFlow version that was used for this example"
assert keras.__version__ == "2.0.8", "Keras version that was used for this example"


Using TensorFlow backend.

1.1 Hyperparameters


In [3]:
batch_size = 100
num_classes = 10 # how many categories there are
embedding_dim = 32 # how many dimensions the embedded space has 
epochs = 30
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'keras_cifar10_trained_model.h5'

1.2 Data (train, test)


In [4]:
# The data, shuffled and split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# normalize data
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255


x_train shape: (50000, 32, 32, 3)
50000 train samples
10000 test samples

1.3 Triplet loss function

Some remarks on the implementation (see Section 3.2 of [1]):

  • As in [1], we focus on online triplet generation. That means we calculate all triplets within one batch.
  • As in [1], we use all anchor-positive pairs.
  • In contrast to [1], we also use all anchor-negative pairs.
  • We don't use any pre-sampling of data classes.

Notation

  • To denote an embedded image we use $z_i^a$ instead of $f(x_i^a)$

Loss calculation

Calculation of the triplet loss per batch can roughly be divided into five steps:

  • Step 1) Calculation of the pairwise Euclidean distances $\|z_i - z_j\|$ for all pairs $i,j$ in the batch of embedded images (the implementation keeps the squared distances; the final square root is commented out). The Euclidean distances are used on multiple occasions in the paper, e.g. Equations (1), (2) and (3) in [1]. The result will be a [batch_size x batch_size] shaped 2d tensor which contains the pairwise distances; see the NumPy sketch after this list. See [4] for more details on the implementation and the results.

  • Step 2) Next, we calculate pairwise label equality. The result will be a [batch_size x batch_size] shaped 2d tensor which indicates, for each pair of training examples, whether they belong to the same class. For example, if the pairwise label equality matrix is $L$ and $L_{3,4}=1$, this means that image $x_3$ has the same label as $x_4$. We need this pairwise label equality to determine positive and negative pairs.

  • Step 3) Next, we get the Euclidean distances of all positive and of all negative examples. We check which pairs are positive (using the label equality matrix) and select these entries from the Euclidean distance matrix. We set the elements on the diagonal to -1 so that the distance of an image to itself, i.e. the anchor-to-anchor case, is ignored in a later step. The results of Step 3 are two [batch_size x batch_size] shaped 2d tensors: the positive one contains all $\|z_i^a-z_j^p\|$, whereas the negative one contains all $\|z_i^a-z_j^n\|$ for all pairs $i,j$ in this batch.

  • Step 4) Build all possible (positive, negative) combinations per row, i.e. all triplets that share the row's image as anchor.

  • Step 5) Keep only the triplets that violate the triplet constraint and use them to calculate the loss.
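
To make Steps 1-5 concrete, below is a small NumPy sketch with made-up toy numbers (an illustration only, not part of the training graph). It uses the expansion $\|z_i - z_j\|^2 = \|z_i\|^2 - 2 z_i^\top z_j + \|z_j\|^2$ from Step 1.

In [ ]:
import numpy as np

z = np.array([[0., 0.], [1., 0.], [0., 2.]])   # toy batch of 3 embedded images
y = np.array([0, 1, 0])                        # labels; images 0 and 2 share a class

# Step 1: ||z_i - z_j||^2 = ||z_i||^2 - 2 z_i.z_j + ||z_j||^2
row_norms = np.sum(z ** 2, axis=1, keepdims=True)       # [3, 1]
sq_dists = row_norms - 2 * z.dot(z.T) + row_norms.T     # [3, 3]; e.g. sq_dists[0, 1] == 1.0

# Step 2: pairwise label equality L, with L[i, j] == 1 iff y_i == y_j
label_equality = (y[:, None] == y[None, :]).astype(int)

# Step 3: select positive/negative distances, marking invalid entries with -1
eye = np.eye(3, dtype=int)
positive_ed = np.where(label_equality - eye == 1, sq_dists, -1.0)  # same label, not self
negative_ed = np.where(label_equality == 0, sq_dists, -1.0)        # different label

# Steps 4+5: all (positive, negative) combinations per row, hinged with margin alpha
alpha = 0.8
pos_row = np.tile(positive_ed.reshape(-1, 1), (1, 3))   # [9, 3]
neg_col = np.tile(negative_ed, (1, 3)).reshape(-1, 3)   # [9, 3]
valid = (pos_row >= 0) & (neg_col >= 0) & (neg_col - alpha < pos_row)
print(np.where(valid, np.maximum(pos_row - neg_col + alpha, 0), 0.0))

Here the anchor-positive pair (0, 2) with squared distance 4 and the negative (0, 1) with squared distance 1 violate the margin, contributing a hinge loss of 4 - 1 + 0.8 = 3.8; all other entries are zero.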

Warning:

  • The way this loss function is implemented, the batch size is fixed at graph construction time; batches of any other size are not supported.

In [5]:
def triplet_loss(y_true, y_pred, alpha=0.8, batch_size=batch_size):
    print("compiling triplet loss: %0.5f"%alpha) 
    print("Y_pred(these are the embedded images) shape: %s"% y_pred.shape)
    print("Y true(these are the labels of the images )shape : %s"% y_true.shape)
    
    z=tf.cast(y_pred, tf.float64)   
    
    """
        1) calculate pairwise squared euclidean distances
        
        In the first step we calculate the pairwise squared euclidean distance from each embedded image to every other. If we had three images in a batch (z1,z2,z3), the result would be a 3x3 matrix containing ||z_i - z_j||^2 for all pairs.
    """
    z_row_norm = tf.reduce_sum(tf.pow(z,2), axis=1, keep_dims=True) # [batch_size, 1]
    squared_distances=tf.matmul(a=z,b=z,transpose_a=False,transpose_b=True) # => [batch_size, batch_size]
    squared_distances = -2 * squared_distances 
    squared_distances = squared_distances + z_row_norm # => broadcast as row vector 
    pw_sqrd_euclid_dists = tf.abs(squared_distances + tf.transpose(z_row_norm)) # => broadcast as column vector; use tf.abs because very small -0 floats 
    #pw_euclid_dists = tf.sqrt(pw_sqrd_euclid_dists)        
    
    """
        2) get pairwise label equality
        
        In this step we calculate which of the true labels are equal to each other. 
    """ 
    y_row = tf.expand_dims(K.flatten(y_true), 0) # => [1, batch_size]
    y_row_ary = tf.tile(y_row, [batch_size, 1])
    pw_label_equality = tf.cast(tf.equal(y_row_ary, tf.transpose(y_row_ary)), tf.int32)

    """
        3) Define all positive examples and all negative examples
        
        An example is positive if it has the same label as the anchor.
        Anchors themselves lie on the diagonal (self distances), so they are excluded.
    
    """
    # get all positive examples 
    positive_labels_cond = tf.not_equal(pw_label_equality, tf.eye(batch_size, dtype=tf.int32))
    positive_ed = tf.where(condition=positive_labels_cond , x=pw_sqrd_euclid_dists, y=tf.ones_like(pw_sqrd_euclid_dists)*-1)
    positive_ed = tf.add(positive_ed, tf.eye(batch_size, dtype=tf.float64)*-1) # exclude self distances
    
    # get all negative examples
    negative_labels_cond = tf.equal(pw_label_equality, tf.zeros_like(pw_label_equality, dtype=tf.int32))
    negative_ed = tf.where(condition=negative_labels_cond , x=pw_sqrd_euclid_dists, y=tf.ones_like(pw_sqrd_euclid_dists)*-1)
    negative_ed = tf.add(negative_ed, tf.eye(batch_size, dtype=tf.float64)*-1) # exclude self distances
    
    
    """
        4) Get all possible triplet permutations for each row of the batch. 
    """
    pos_row = tf.tile(tf.reshape(positive_ed, [-1, 1]), [1, batch_size])
    neg_col = tf.reshape(tf.tile(negative_ed, [1 , batch_size]), [-1, batch_size])
    
    """
        5) Select the ones that violate the triplet constraint 
    """
    # conditions: exclude all invalid entries (marked -1 above)
    # we want the distance a=>n to be larger than the distance a=>p plus the margin alpha,
    # so we catch all violations where d(a,n) - alpha < d(a,p)
    neg_greater_zero = tf.greater_equal(neg_col, tf.zeros_like(neg_col)) # valid negative entries
    pos_greater_zero = tf.greater_equal(pos_row, tf.zeros_like(pos_row)) # valid positive entries
    d_pos_less_than_d_neg = tf.less(x=neg_col-alpha, y=pos_row) # triplets that violate the margin constraint
    hinge_loss = tf.maximum(pos_row-neg_col+alpha, 0) # hinge loss for all permutations
    permutations_loss = tf.where(tf.logical_and(tf.logical_and(neg_greater_zero, d_pos_less_than_d_neg), pos_greater_zero), hinge_loss, tf.zeros_like(pos_row))
    # => shape [batch_size*batch_size, batch_size]: all (positive, negative) combinations per row of the batch

    """
        6) Sum up 
    """    
    num_non_zero_perms = tf.reduce_sum(tf.cast(tf.greater(x=permutations_loss, y=tf.zeros_like(permutations_loss)), tf.float64))
    mean_permutation_loss = tf.reduce_sum(permutations_loss, axis=1) / num_non_zero_perms # average only over non-zero losses, because 0 marks invalid permutations
    # => shape [batch_size*batch_size]
    per_example_loss = tf.reshape(mean_permutation_loss, [batch_size, batch_size]) # all valid permutations per example
    
    total_hinge_loss = tf.reduce_sum(per_example_loss)
    
    return tf.cast(total_hinge_loss, tf.float32)
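
To sanity-check the loss outside of Keras, we can feed random embeddings and labels through the function directly. This is a minimal sketch assuming the TensorFlow 1.x session API; the placeholder names below are made up for illustration.

In [ ]:
# hypothetical sanity check: evaluate triplet_loss on random data of the fixed batch size
y_true_ph = tf.placeholder(tf.float32, shape=(batch_size, 1))               # labels
y_pred_ph = tf.placeholder(tf.float32, shape=(batch_size, embedding_dim))   # embeddings
loss_op = triplet_loss(y_true_ph, y_pred_ph)

with tf.Session() as sess:
    loss_value = sess.run(loss_op, feed_dict={
        y_true_ph: np.random.randint(0, num_classes, size=(batch_size, 1)).astype("float32"),
        y_pred_ph: np.random.randn(batch_size, embedding_dim).astype("float32")})
    print("triplet loss on random embeddings: %0.4f" % loss_value)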

1.4 Model


In [6]:
# Build the embedding model.
model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(embedding_dim))
model.add(Activation('linear'))

# initiate RMSprop optimizer
opt = keras.optimizers.rmsprop(lr=0.001, decay=1e-6)

1.5 Training procedure


In [7]:
# Compile the model with the triplet loss and the RMSprop optimizer
model.compile(loss=triplet_loss,
              optimizer=opt,
              metrics=[])


compiling triplet loss: 0.80000
Y_pred (these are the embedded images) shape: (?, 32)
Y_true (these are the labels of the images) shape: (?, ?)

2 Training

  • It is unclear whether accuracy makes sense for evaluation
  • It is unclear whether dropout makes sense here

In [8]:
print(y_test.shape, x_test.shape)


(10000, 1) (10000, 32, 32, 3)

In [9]:
print('Using real-time data augmentation.')
# This will do preprocessing and realtime data augmentation:
datagen = ImageDataGenerator(
    featurewise_center=False,  # set input mean to 0 over the dataset
    samplewise_center=False,  # set each sample mean to 0
    featurewise_std_normalization=False,  # divide inputs by std of the dataset
    samplewise_std_normalization=False,  # divide each input by its std
    zca_whitening=False,  # apply ZCA whitening
    rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
    width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
    height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
    horizontal_flip=True,  # randomly flip images horizontally
    vertical_flip=False)  # do not flip images vertically

# Compute quantities required for feature-wise normalization
# (std, mean, and principal components if ZCA whitening is applied).

datagen.fit(x_train)

# Fit the model on the batches generated by datagen.flow().
model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                    steps_per_epoch=int(np.ceil(x_train.shape[0] / float(batch_size))),
                    epochs=epochs,
        #validation_data=(x_test, y_test),
                    workers=1)

# Save model and weights
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
    
model_path = os.path.join(save_dir, model_name)
model.save(model_path)

print('Saved trained model at %s ' % model_path)


Using real-time data augmentation.
Epoch 1/30
500/500 [==============================] - 250s - loss: 0.8151 - val_loss: 0.7566
Epoch 2/30
500/500 [==============================] - 239s - loss: 0.8144 - val_loss: 0.7791
Epoch 3/30
500/500 [==============================] - 244s - loss: 0.8191 - val_loss: 0.7890
Epoch 4/30
500/500 [==============================] - 243s - loss: 0.8241 - val_loss: 0.7564
Epoch 5/30
500/500 [==============================] - 244s - loss: 0.8255 - val_loss: 0.7904
Epoch 6/30
500/500 [==============================] - 255s - loss: 0.8277 - val_loss: 0.8523
Epoch 7/30
500/500 [==============================] - 240s - loss: 0.8298 - val_loss: 0.7340
Epoch 8/30
500/500 [==============================] - 242s - loss: 0.8310 - val_loss: 0.7553
Epoch 9/30
500/500 [==============================] - 271s - loss: 0.8331 - val_loss: 0.8427
Epoch 10/30
500/500 [==============================] - 244s - loss: 0.8336 - val_loss: 0.7703
Epoch 11/30
500/500 [==============================] - 257s - loss: 0.8328 - val_loss: 0.8792
Epoch 12/30
500/500 [==============================] - 253s - loss: 0.8359 - val_loss: 0.8217
Epoch 13/30
500/500 [==============================] - 253s - loss: 0.8353 - val_loss: 0.7663
Epoch 14/30
500/500 [==============================] - 251s - loss: 0.8352 - val_loss: 0.8670
Epoch 15/30
500/500 [==============================] - 226s - loss: 0.8358 - val_loss: 0.8358
Epoch 16/30
500/500 [==============================] - 218s - loss: 0.8387 - val_loss: 0.8345
Epoch 17/30
500/500 [==============================] - 222s - loss: 0.8375 - val_loss: 0.8838
Epoch 18/30
500/500 [==============================] - 262s - loss: 0.8390 - val_loss: 0.8328
Epoch 19/30
500/500 [==============================] - 264s - loss: 0.8391 - val_loss: 0.7849
Epoch 20/30
500/500 [==============================] - 253s - loss: 0.8396 - val_loss: 0.8237
Epoch 21/30
500/500 [==============================] - 250s - loss: 0.8416 - val_loss: 0.7755
Epoch 22/30
500/500 [==============================] - 238s - loss: 0.8418 - val_loss: 0.8070
Epoch 23/30
500/500 [==============================] - 242s - loss: 0.8436 - val_loss: 0.8091
Epoch 24/30
500/500 [==============================] - 229s - loss: 0.8441 - val_loss: 0.8043
Epoch 25/30
500/500 [==============================] - 246s - loss: 0.8441 - val_loss: 0.7995
Epoch 26/30
500/500 [==============================] - 225s - loss: 0.8457 - val_loss: 0.8172
Epoch 27/30
500/500 [==============================] - 241s - loss: 0.8463 - val_loss: 0.7875
Epoch 28/30
500/500 [==============================] - 271s - loss: 0.8469 - val_loss: 0.8079
Epoch 29/30
500/500 [==============================] - 247s - loss: 0.8477 - val_loss: 0.8008
Epoch 30/30
500/500 [==============================] - 220s - loss: 0.8477 - val_loss: 0.7896
Saved trained model at /home/sthaler/Repositories/tf-spikes/triplet-loss/saved_models/keras_cifar10_trained_model.h5 

3 Model evaluation


In [ ]:
# https://stackoverflow.com/questions/33436221/displaying-rotatable-3d-plots-in-ipython-or-ipython-notebook

Embed all test images


In [10]:
embedded_test = np.ndarray((x_test.shape[0], embedding_dim), dtype="float32")

for i in range(int(x_test.shape[0]/batch_size )):
    start_idx = i * batch_size
    end_idx = (i+1) * batch_size
    embedded_test[start_idx:end_idx] = model.predict(x_test[start_idx:end_idx])
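
Note: the fixed-batch restriction only concerns the loss function, so model.predict(x_test, batch_size=batch_size) should yield the same embeddings as the loop above.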

Cluster


In [11]:
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import v_measure_score
from sklearn.metrics import adjusted_mutual_info_score
num_classes=10
kmeans = KMeans(n_clusters=num_classes, random_state=0, n_init=20).fit(embedded_test)

print("KMeans V-Measure", v_measure_score(labels_true=y_test.flatten(), labels_pred=kmeans.labels_))
print("KMeans AMI", adjusted_mutual_info_score(labels_true=y_test.flatten(), labels_pred=kmeans.labels_))


KMeans V-Measure 0.600570613955
KMeans AMI 0.599447185844

In [12]:
import sklearn
#np.seterr(divide='ignore', invalid='ignore')
silhouette = sklearn.metrics.silhouette_score(
    X=embedded_test, 
    labels=y_test.flatten(), 
    metric='euclidean')
print("Silhouette: %0.3f"%silhouette)


Silhouette: 0.224

Plot


In [13]:
# get a hex color range for number of parts
def get_N_HexCol(N=5):
    import colorsys # for get_N_HexCol
    HSV_tuples = [(x*1.0/N, 1, 1) for x in range(N)]
    hex_out = []
    for rgb in HSV_tuples:
        rgb = tuple(map(lambda x: int(x*255),colorsys.hsv_to_rgb(*rgb)))
        hex_out.append("#%.2X%.2X%.2X"%rgb )
    return hex_out
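
For example, get_N_HexCol(3) returns three evenly spaced hues: ['#FF0000', '#00FF00', '#0000FF'].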

In [17]:
%matplotlib notebook 

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from sklearn import decomposition
from matplotlib.markers import MarkerStyle

pca = decomposition.PCA(n_components=3)
pca.fit(embedded_test)
test_reduced = pca.transform(embedded_test)

fig = plt.figure(figsize=(20,20))

ax = fig.add_subplot(111, projection='3d')                                  

colors=get_N_HexCol(N=10)
markers = list(MarkerStyle().markers.keys())
for i, x in enumerate(test_reduced[:1000]):
    class_id = y_test[i][0]
    ax.scatter(x[0], x[1], x[2], c=colors[class_id], marker=markers[class_id])

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

plt.show()


